{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PyTorch Tutorial 03: Define Networks\n",
    "## Overview\n",
    "In this tutorial, we explain how to define networks. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0.1283, 0.7627, 0.0181,  ..., 0.8569, 0.4531, 0.4337],\n",
       "         [0.3366, 0.5344, 0.2687,  ..., 0.2099, 0.7453, 0.0209],\n",
       "         [0.4089, 0.6110, 0.5130,  ..., 0.1261, 0.9310, 0.4383],\n",
       "         ...,\n",
       "         [0.2976, 0.0961, 0.5546,  ..., 0.8212, 0.6335, 0.3497],\n",
       "         [0.3701, 0.0833, 0.4147,  ..., 0.2035, 0.4569, 0.4486],\n",
       "         [0.2438, 0.6422, 0.8251,  ..., 0.3859, 0.9151, 0.4302]],\n",
       "\n",
       "        [[0.2359, 0.1646, 0.9636,  ..., 0.1148, 0.2140, 0.5944],\n",
       "         [0.3071, 0.6690, 0.1750,  ..., 0.1503, 0.1939, 0.9889],\n",
       "         [0.4855, 0.2156, 0.6161,  ..., 0.2892, 0.8562, 0.4031],\n",
       "         ...,\n",
       "         [0.7569, 0.7194, 0.5767,  ..., 0.4001, 0.1561, 0.0242],\n",
       "         [0.7725, 0.8491, 0.1711,  ..., 0.9211, 0.9180, 0.9002],\n",
       "         [0.8873, 0.2162, 0.2717,  ..., 0.4897, 0.5422, 0.9277]],\n",
       "\n",
       "        [[0.2835, 0.1043, 0.4089,  ..., 0.8274, 0.1067, 0.7595],\n",
       "         [0.4789, 0.1006, 0.1152,  ..., 0.2405, 0.4191, 0.4407],\n",
       "         [0.3027, 0.0196, 0.8439,  ..., 0.3833, 0.7521, 0.3324],\n",
       "         ...,\n",
       "         [0.2466, 0.9571, 0.2300,  ..., 0.2305, 0.4215, 0.6862],\n",
       "         [0.9049, 0.5634, 0.8508,  ..., 0.4987, 0.4500, 0.1857],\n",
       "         [0.9792, 0.3026, 0.3808,  ..., 0.7751, 0.9458, 0.5097]],\n",
       "\n",
       "        ...,\n",
       "\n",
       "        [[0.8145, 0.0907, 0.8197,  ..., 0.7462, 0.0697, 0.4390],\n",
       "         [0.5846, 0.9178, 0.8144,  ..., 0.5008, 0.2567, 0.2260],\n",
       "         [0.6274, 0.2819, 0.8423,  ..., 0.8771, 0.2793, 0.0158],\n",
       "         ...,\n",
       "         [0.4975, 0.7932, 0.9046,  ..., 0.5415, 0.4966, 0.2313],\n",
       "         [0.5214, 0.9734, 0.9920,  ..., 0.7577, 0.0980, 0.3426],\n",
       "         [0.1179, 0.9615, 0.2107,  ..., 0.8465, 0.2767, 0.4508]],\n",
       "\n",
       "        [[0.7727, 0.1136, 0.1143,  ..., 0.0377, 0.4342, 0.7612],\n",
       "         [0.2998, 0.0757, 0.6833,  ..., 0.0286, 0.2695, 0.4993],\n",
       "         [0.9869, 0.3836, 0.9676,  ..., 0.9934, 0.1146, 0.5886],\n",
       "         ...,\n",
       "         [0.3846, 0.9321, 0.9261,  ..., 0.4434, 0.8544, 0.8308],\n",
       "         [0.6225, 0.3738, 0.0256,  ..., 0.4857, 0.3781, 0.5225],\n",
       "         [0.5433, 0.8492, 0.5744,  ..., 0.7890, 0.5905, 0.3064]],\n",
       "\n",
       "        [[0.0022, 0.4302, 0.2720,  ..., 0.0330, 0.1630, 0.4703],\n",
       "         [0.5033, 0.4198, 0.5062,  ..., 0.3438, 0.3919, 0.8449],\n",
       "         [0.1435, 0.7137, 0.2919,  ..., 0.0269, 0.0864, 0.2255],\n",
       "         ...,\n",
       "         [0.4896, 0.9618, 0.2899,  ..., 0.8602, 0.7582, 0.3978],\n",
       "         [0.3311, 0.2050, 0.2950,  ..., 0.9023, 0.5985, 0.8504],\n",
       "         [0.8554, 0.9793, 0.7188,  ..., 0.3301, 0.1498, 0.5337]]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand([8, 100, 10]).detach()\n",
    "x "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1, 1, 0, 1, 1, 0, 0, 1], dtype=torch.int32)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = torch.rand(8)\n",
    "y = (y>0.5).int()\n",
    "y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MLP, self).__init__()\n",
    "        self.first_layer = nn.Linear(1000,50)\n",
    "        self.second_layer = nn.Linear(50, 1)\n",
    "    def forward(self, x):\n",
    "        x = torch.flatten(x, start_dim=1, end_dim=2)\n",
    "        x = nn.functional.relu(self.first_layer(x))\n",
    "        x = self.second_layer(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "mlp = MLP()\n",
    "output = mlp(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.1030],\n",
       "        [-0.2689],\n",
       "        [-0.1698],\n",
       "        [-0.1140],\n",
       "        [-0.3090],\n",
       "        [-0.1686],\n",
       "        [-0.2050],\n",
       "        [-0.1751]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Embedding(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Embedding, self).__init__()\n",
    "        self.embedding = nn.Embedding(4, 100)\n",
    "    def forward(self, x):\n",
    "        return self.embedding(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding = Embedding()\n",
    "\n",
    "embedding_input = torch.tensor([[0,1, 0],[2,3, 3]])\n",
    "embedding_output = embedding(embedding_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 3, 100])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LSTM, self).__init__()\n",
    "        self.lstm = nn.LSTM(10, \n",
    "                           15, \n",
    "                           num_layers=2, \n",
    "                           bidirectional=True, \n",
    "                           dropout=0.1)\n",
    "    def forward(self, x):\n",
    "        output, (hidden, cell) = self.lstm(x)\n",
    "        return output, hidden, cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "permute_x = x.permute([1,0,2])\n",
    "lstm = LSTM()\n",
    "output_lstm1, output_lstm2, output_lstm3 = lstm(permute_x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 8, 15])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output_lstm2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Conv(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Conv, self).__init__()\n",
    "        self.conv1d = nn.Conv1d(100, 50, 2)\n",
    "    def forward(self, x):\n",
    "        return self.conv1d(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "conv = Conv()\n",
    "output = conv(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 50, 9])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output.shape"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
